{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Written by, \n",
    "Sriram Ravindran, sriram@ucsd.edu\n",
    "\n",
    "Original paper - https://arxiv.org/abs/1611.08024\n",
    "\n",
    "Please reach out to me if you spot an error.\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import roc_auc_score, precision_score, recall_score, accuracy_score\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p>Here's the description from the paper</p>\n",
    "<img src=\"EEGNet.png\" style=\"width: 700px; float:left;\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Variable containing:\n",
      " 0.7338\n",
      "[torch.cuda.FloatTensor of size 1x1 (GPU 0)]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "class EEGNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(EEGNet, self).__init__()\n",
    "        self.T = 120\n",
    "        \n",
    "        # Layer 1\n",
    "        self.conv1 = nn.Conv2d(1, 16, (1, 64), padding = 0)\n",
    "        self.batchnorm1 = nn.BatchNorm2d(16, False)\n",
    "        \n",
    "        # Layer 2\n",
    "        self.padding1 = nn.ZeroPad2d((16, 17, 0, 1))\n",
    "        self.conv2 = nn.Conv2d(1, 4, (2, 32))\n",
    "        self.batchnorm2 = nn.BatchNorm2d(4, False)\n",
    "        self.pooling2 = nn.MaxPool2d(2, 4)\n",
    "        \n",
    "        # Layer 3\n",
    "        self.padding2 = nn.ZeroPad2d((2, 1, 4, 3))\n",
    "        self.conv3 = nn.Conv2d(4, 4, (8, 4))\n",
    "        self.batchnorm3 = nn.BatchNorm2d(4, False)\n",
    "        self.pooling3 = nn.MaxPool2d((2, 4))\n",
    "        \n",
    "        # FC Layer\n",
    "        # NOTE: This dimension will depend on the number of timestamps per sample in your data.\n",
    "        # I have 120 timepoints. \n",
    "        self.fc1 = nn.Linear(4*2*7, 1)\n",
    "        \n",
    "\n",
    "    def forward(self, x):\n",
    "        # Layer 1\n",
    "        x = F.elu(self.conv1(x))\n",
    "        x = self.batchnorm1(x)\n",
    "        x = F.dropout(x, 0.25)\n",
    "        x = x.permute(0, 3, 1, 2)\n",
    "        \n",
    "        # Layer 2\n",
    "        x = self.padding1(x)\n",
    "        x = F.elu(self.conv2(x))\n",
    "        x = self.batchnorm2(x)\n",
    "        x = F.dropout(x, 0.25)\n",
    "        x = self.pooling2(x)\n",
    "        \n",
    "        # Layer 3\n",
    "        x = self.padding2(x)\n",
    "        x = F.elu(self.conv3(x))\n",
    "        x = self.batchnorm3(x)\n",
    "        x = F.dropout(x, 0.25)\n",
    "        x = self.pooling3(x)\n",
    "        \n",
    "        # FC Layer\n",
    "        x = x.view(-1, 4*2*7)\n",
    "        x = F.sigmoid(self.fc1(x))\n",
    "        return x\n",
    "\n",
    "\n",
    "net = EEGNet().cuda(0)\n",
    "print net.forward(Variable(torch.Tensor(np.random.rand(1, 1, 120, 64)).cuda(0)))\n",
    "criterion = nn.BCELoss()\n",
    "optimizer = optim.Adam(net.parameters())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Evaluate function returns values of different criteria like accuracy, precision etc. \n",
    "In case you face memory overflow issues, use batch size to control how many samples get evaluated at one time. Use a batch_size that is a factor of length of samples. This ensures that you won't miss any samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def evaluate(model, X, Y, params = [\"acc\"]):\n",
    "    results = []\n",
    "    batch_size = 100\n",
    "    \n",
    "    predicted = []\n",
    "    \n",
    "    for i in range(len(X)/batch_size):\n",
    "        s = i*batch_size\n",
    "        e = i*batch_size+batch_size\n",
    "        \n",
    "        inputs = Variable(torch.from_numpy(X[s:e]).cuda(0))\n",
    "        pred = model(inputs)\n",
    "        \n",
    "        predicted.append(pred.data.cpu().numpy())\n",
    "        \n",
    "        \n",
    "    inputs = Variable(torch.from_numpy(X).cuda(0))\n",
    "    predicted = model(inputs)\n",
    "    \n",
    "    predicted = predicted.data.cpu().numpy()\n",
    "    \n",
    "    for param in params:\n",
    "        if param == 'acc':\n",
    "            results.append(accuracy_score(Y, np.round(predicted)))\n",
    "        if param == \"auc\":\n",
    "            results.append(roc_auc_score(Y, predicted))\n",
    "        if param == \"recall\":\n",
    "            results.append(recall_score(Y, np.round(predicted)))\n",
    "        if param == \"precision\":\n",
    "            results.append(precision_score(Y, np.round(predicted)))\n",
    "        if param == \"fmeasure\":\n",
    "            precision = precision_score(Y, np.round(predicted))\n",
    "            recall = recall_score(Y, np.round(predicted))\n",
    "            results.append(2*precision*recall/ (precision+recall))\n",
    "    return results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Generate random data\n",
    "\n",
    "##### Data format:\n",
    "Datatype - float32 (both X and Y) <br>\n",
    "X.shape - (#samples, 1, #timepoints,  #channels) <br>\n",
    "Y.shape - (#samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "X_train = np.random.rand(100, 1, 120, 64).astype('float32') # np.random.rand generates between [0, 1)\n",
    "y_train = np.round(np.random.rand(100).astype('float32')) # binary data, so we round it to 0 or 1.\n",
    "\n",
    "X_val = np.random.rand(100, 1, 120, 64).astype('float32')\n",
    "y_val = np.round(np.random.rand(100).astype('float32'))\n",
    "\n",
    "X_test = np.random.rand(100, 1, 120, 64).astype('float32')\n",
    "y_test = np.round(np.random.rand(100).astype('float32'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch  0\n",
      "['acc', 'auc', 'fmeasure']\n",
      "Training Loss  1.54113572836\n",
      "Train -  [0.54000000000000004, 0.59178743961352653, 0.70129870129870131]\n",
      "Validation -  [0.51000000000000001, 0.48539415766306526, 0.67549668874172186]\n",
      "Test -  [0.5, 0.50319999999999998, 0.66666666666666663]\n",
      "\n",
      "Epoch  1\n",
      "['acc', 'auc', 'fmeasure']\n",
      "Training Loss  1.42391115427\n",
      "Train -  [0.54000000000000004, 0.63888888888888895, 0.70129870129870131]\n",
      "Validation -  [0.51000000000000001, 0.47458983593437376, 0.67549668874172186]\n",
      "Test -  [0.5, 0.50439999999999996, 0.66666666666666663]\n",
      "\n",
      "Epoch  2\n",
      "['acc', 'auc', 'fmeasure']\n",
      "Training Loss  1.3422973156\n",
      "Train -  [0.55000000000000004, 0.67995169082125606, 0.70198675496688734]\n",
      "Validation -  [0.53000000000000003, 0.46898759503801518, 0.68456375838926176]\n",
      "Test -  [0.51000000000000001, 0.50800000000000001, 0.67114093959731547]\n",
      "\n",
      "Epoch  3\n",
      "['acc', 'auc', 'fmeasure']\n",
      "Training Loss  1.28801095486\n",
      "Train -  [0.63, 0.71054750402576483, 0.73758865248226957]\n",
      "Validation -  [0.48999999999999999, 0.4601840736294518, 0.63309352517985618]\n",
      "Test -  [0.52000000000000002, 0.51000000000000001, 0.66666666666666663]\n",
      "\n",
      "Epoch  4\n",
      "['acc', 'auc', 'fmeasure']\n",
      "Training Loss  1.25420039892\n",
      "Train -  [0.68999999999999995, 0.74476650563607083, 0.75590551181102361]\n",
      "Validation -  [0.42999999999999999, 0.44457783113245297, 0.53658536585365846]\n",
      "Test -  [0.51000000000000001, 0.51000000000000001, 0.6080000000000001]\n",
      "\n",
      "Epoch  5\n",
      "['acc', 'auc', 'fmeasure']\n",
      "Training Loss  1.22989922762\n",
      "Train -  [0.75, 0.77375201288244766, 0.77876106194690276]\n",
      "Validation -  [0.46000000000000002, 0.43937575030012005, 0.49056603773584906]\n",
      "Test -  [0.51000000000000001, 0.50440000000000007, 0.55855855855855863]\n",
      "\n",
      "Epoch  6\n",
      "['acc', 'auc', 'fmeasure']\n",
      "Training Loss  1.20727479458\n",
      "Train -  [0.76000000000000001, 0.79227053140096615, 0.7735849056603773]\n",
      "Validation -  [0.40999999999999998, 0.43897559023609439, 0.40404040404040403]\n",
      "Test -  [0.47999999999999998, 0.496, 0.5]\n",
      "\n",
      "Epoch  7\n",
      "['acc', 'auc', 'fmeasure']\n",
      "Training Loss  1.18265104294\n",
      "Train -  [0.76000000000000001, 0.81119162640901765, 0.7735849056603773]\n",
      "Validation -  [0.40999999999999998, 0.43897559023609445, 0.40404040404040403]\n",
      "Test -  [0.44, 0.48999999999999999, 0.44]\n",
      "\n",
      "Epoch  8\n",
      "['acc', 'auc', 'fmeasure']\n",
      "Training Loss  1.15454357862\n",
      "Train -  [0.80000000000000004, 0.8248792270531401, 0.81132075471698106]\n",
      "Validation -  [0.40000000000000002, 0.43497398959583838, 0.40000000000000002]\n",
      "Test -  [0.48999999999999999, 0.48560000000000003, 0.51428571428571423]\n",
      "\n",
      "Epoch  9\n",
      "['acc', 'auc', 'fmeasure']\n",
      "Training Loss  1.12422537804\n",
      "Train -  [0.81000000000000005, 0.83816425120772953, 0.82568807339449546]\n",
      "Validation -  [0.40999999999999998, 0.43177270908363347, 0.41584158415841577]\n",
      "Test -  [0.47999999999999998, 0.4768, 0.52727272727272734]\n"
     ]
    }
   ],
   "source": [
    "batch_size = 32\n",
    "\n",
    "for epoch in range(10):  # loop over the dataset multiple times\n",
    "    print \"\\nEpoch \", epoch\n",
    "    \n",
    "    running_loss = 0.0\n",
    "    for i in range(len(X_train)/batch_size-1):\n",
    "        s = i*batch_size\n",
    "        e = i*batch_size+batch_size\n",
    "        \n",
    "        inputs = torch.from_numpy(X_train[s:e])\n",
    "        labels = torch.FloatTensor(np.array([y_train[s:e]]).T*1.0)\n",
    "        \n",
    "        # wrap them in Variable\n",
    "        inputs, labels = Variable(inputs.cuda(0)), Variable(labels.cuda(0))\n",
    "\n",
    "        # zero the parameter gradients\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # forward + backward + optimize\n",
    "        outputs = net(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        \n",
    "        \n",
    "        optimizer.step()\n",
    "        \n",
    "        running_loss += loss.data[0]\n",
    "    \n",
    "    # Validation accuracy\n",
    "    params = [\"acc\", \"auc\", \"fmeasure\"]\n",
    "    print params\n",
    "    print \"Training Loss \", running_loss\n",
    "    print \"Train - \", evaluate(net, X_train, y_train, params)\n",
    "    print \"Validation - \", evaluate(net, X_val, y_val, params)\n",
    "    print \"Test - \", evaluate(net, X_test, y_test, params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
